{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Optional Lab - Multi-class Classification\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 1.1 Goals\n",
    "In this lab, you will explore an example of multi-class classification using neural networks.\n",
    "<figure>\n",
    " <img src=\"./images/C2_W2_mclass_header.png\"   style=\"width:500px;height:200px;\">\n",
    "</figure>\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 1.2 Tools\n",
    "You will use some plotting routines. These are stored in `lab_utils_multiclass_TF.py` in this directory."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "%matplotlib widget\n",
    "from sklearn.datasets import make_blobs\n",
    "import tensorflow as tf\n",
    "from tensorflow.keras.models import Sequential\n",
    "from tensorflow.keras.layers import Dense\n",
    "np.set_printoptions(precision=2)\n",
    "from lab_utils_multiclass_TF import *\n",
    "import logging\n",
    "logging.getLogger(\"tensorflow\").setLevel(logging.ERROR)\n",
    "tf.autograph.set_verbosity(0)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "tags": []
   },
   "source": [
    "# 2.0 Multi-class Classification\n",
    "Neural networks are often used to classify data. Examples are neural networks that:\n",
    "- take in photos and classify subjects in the photos as {dog, cat, horse, other}\n",
    "- take in a sentence and classify the 'parts of speech' of its elements: {noun, verb, adjective, etc.}  \n",
    "\n",
    "A network of this type will have multiple units in its final layer. Each output is associated with a category. When an input example is applied to the network, the output with the highest value is the predicted category. If the outputs are passed through a softmax function, the softmax outputs are the probabilities of the input belonging to each category. \n",
    "\n",
    "In this lab you will see an example of building a multiclass network in TensorFlow. We will then take a look at how the neural network makes its predictions.\n",
    "\n",
    "Let's start by creating a four-class data set."
   ]
  },
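  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As an aside, the softmax step mentioned in section 2.0 can be sketched in NumPy. This is a minimal illustration and not part of the lab's utilities: the helper name `softmax_sketch` and the example scores in `z` are made up for this sketch.\n",
    "\n",
    "```python\n",
    "import numpy as np\n",
    "\n",
    "def softmax_sketch(z):\n",
    "    # hypothetical helper: subtracting the max before exponentiating avoids overflow\n",
    "    ez = np.exp(z - np.max(z))\n",
    "    return ez / np.sum(ez)\n",
    "\n",
    "z = np.array([1.0, 3.0, 0.5, -2.0])   # made-up outputs of a four-unit final layer\n",
    "p = softmax_sketch(z)\n",
    "\n",
    "print(f'probabilities: {p}, sum: {np.sum(p):.2f}')\n",
    "print(f'predicted category: {np.argmax(z)}')   # index of the largest output\n",
    "```\n",
    "\n",
    "Because the softmax preserves the ordering of its inputs, picking the largest raw output and picking the largest probability give the same category."
   ]
  },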
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 2.1 Prepare and visualize our data\n",
    "We will use Scikit-Learn's `make_blobs` function to make a training data set with 4 categories as shown in the plot below."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "# make 4-class dataset for classification\n",
    "classes = 4\n",
    "m = 100\n",
    "centers = [[-5, 2], [-2, -2], [1, 2], [5, -2]]\n",
    "std = 1.0\n",
    "X_train, y_train = make_blobs(n_samples=m, centers=centers, cluster_std=std,random_state=30)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "c310768015c5481596e10c68d936f66c",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "plt_mc(X_train,y_train,classes, centers, std=std)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Each dot represents a training example. The axes (x0, x1) are the inputs and the color represents the class the example is associated with. Once trained, the model will be presented with a new example, (x0, x1), and will predict the class.  \n",
    "\n",
    "Although synthetic, this data set is representative of many real-world classification problems. There are several input features (x0,...,xn) and several output categories. The model is trained to use the input features to predict the correct output category."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "unique classes [0 1 2 3]\n",
      "class representation [3 3 3 0 3 3 3 3 2 0]\n",
      "shape of X_train: (100, 2), shape of y_train: (100,)\n"
     ]
    }
   ],
   "source": [
    "# show classes in data set\n",
    "print(f\"unique classes {np.unique(y_train)}\")\n",
    "# show how classes are represented\n",
    "print(f\"class representation {y_train[:10]}\")\n",
    "# show shapes of our dataset\n",
    "print(f\"shape of X_train: {X_train.shape}, shape of y_train: {y_train.shape}\")"
   ]
  },
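  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The labels printed above are plain integers, one class index per example. Some loss functions expect a one-hot encoding instead, so it can help to see how the two forms correspond. The sketch below uses a few made-up labels (the array `y_example` is illustrative and is not taken from `y_train`).\n",
    "\n",
    "```python\n",
    "import numpy as np\n",
    "\n",
    "classes = 4\n",
    "y_example = np.array([3, 0, 2, 1])        # made-up integer class labels\n",
    "\n",
    "# row i of the identity matrix is the one-hot vector for class i\n",
    "y_one_hot = np.eye(classes)[y_example]\n",
    "print(y_one_hot)\n",
    "\n",
    "# argmax along each row recovers the integer labels\n",
    "print(np.argmax(y_one_hot, axis=1))\n",
    "```\n",
    "\n",
    "This lab keeps the integer form throughout, so no conversion is needed."
   ]
  },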
  {
   "cell_type": "markdown",
   "metadata": {
    "tags": []
   },
   "source": [
    "## 2.2 Model\n",
    "<img align=\"Right\" src=\"./images/C2_W2_mclass_lab_network.PNG\"  style=\" width:350px; padding: 10px 20px ; \">\n",
    "This lab will use a 2-layer network as shown.\n",
    "Unlike the binary classification networks, this network has four outputs, one for each class. Given an input example, the output with the highest value is the predicted class of the input.   \n",
    "\n",
    "Below is an example of how to construct this network in TensorFlow. Notice the output layer uses a `linear` rather than a `softmax` activation. While it is possible to include the softmax in the output layer, it is more numerically stable to pass linear outputs to the loss function during training. If the model is used to predict probabilities, the softmax can be applied at that point."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "tf.random.set_seed(1234)  # applied to achieve consistent results\n",
    "model = Sequential(\n",
    "    [\n",
    "        Dense(2, activation = 'relu',   name = \"L1\"),\n",
    "        Dense(4, activation = 'linear', name = \"L2\")\n",
    "    ]\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The statements below compile and train the network. Setting `from_logits=True` as an argument to the loss function tells it that the model outputs are linear values (logits) rather than softmax probabilities, so the loss applies the softmax itself."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {
    "scrolled": true,
    "tags": []
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 1/200\n",
      "4/4 [==============================] - 0s 1ms/step - loss: 1.8158\n",
      "Epoch 2/200\n",
      "4/4 [==============================] - 0s 1ms/step - loss: 1.6976\n",
      "Epoch 3/200\n",
      "4/4 [==============================] - 0s 1ms/step - loss: 1.5989\n",
      "Epoch 4/200\n",
      "4/4 [==============================] - 0s 1ms/step - loss: 1.5179\n",
      "Epoch 5/200\n",
      "4/4 [==============================] - 0s 988us/step - loss: 1.4369\n",
      "Epoch 6/200\n",
      "4/4 [==============================] - 0s 1ms/step - loss: 1.3756\n",
      "Epoch 7/200\n",
      "4/4 [==============================] - 0s 1ms/step - loss: 1.3154\n",
      "Epoch 8/200\n",
      "4/4 [==============================] - 0s 1ms/step - loss: 1.2621\n",
      "Epoch 9/200\n",
      "4/4 [==============================] - 0s 1ms/step - loss: 1.2188\n",
      "Epoch 10/200\n",
      "4/4 [==============================] - 0s 1ms/step - loss: 1.1791\n",
      "Epoch 11/200\n",
      "4/4 [==============================] - 0s 1ms/step - loss: 1.1446\n",
      "Epoch 12/200\n",
      "4/4 [==============================] - 0s 997us/step - loss: 1.1129\n",
      "Epoch 13/200\n",
      "4/4 [==============================] - 0s 1ms/step - loss: 1.0827\n",
      "Epoch 14/200\n",
      "4/4 [==============================] - 0s 1ms/step - loss: 1.0516\n",
      "Epoch 15/200\n",
      "4/4 [==============================] - 0s 1ms/step - loss: 1.0225\n",
      "Epoch 16/200\n",
      "4/4 [==============================] - 0s 1ms/step - loss: 0.9967\n",
      "Epoch 17/200\n",
      "4/4 [==============================] - 0s 1ms/step - loss: 0.9681\n",
      "Epoch 18/200\n",
      "4/4 [==============================] - 0s 1ms/step - loss: 0.9392\n",
      "Epoch 19/200\n",
      "4/4 [==============================] - 0s 1ms/step - loss: 0.9092\n",
      "Epoch 20/200\n",
      "4/4 [==============================] - 0s 1ms/step - loss: 0.8771\n",
      "Epoch 21/200\n",
      "4/4 [==============================] - 0s 1ms/step - loss: 0.8461\n",
      "Epoch 22/200\n",
      "4/4 [==============================] - 0s 1ms/step - loss: 0.8099\n",
      "Epoch 23/200\n",
      "4/4 [==============================] - 0s 1ms/step - loss: 0.7771\n",
      "Epoch 24/200\n",
      "4/4 [==============================] - 0s 1ms/step - loss: 0.7485\n",
      "Epoch 25/200\n",
      "4/4 [==============================] - 0s 1ms/step - loss: 0.7215\n",
      "Epoch 26/200\n",
      "4/4 [==============================] - 0s 1ms/step - loss: 0.6967\n",
      "Epoch 27/200\n",
      "4/4 [==============================] - 0s 1ms/step - loss: 0.6742\n",
      "Epoch 28/200\n",
      "4/4 [==============================] - 0s 1ms/step - loss: 0.6540\n",
      "Epoch 29/200\n",
      "4/4 [==============================] - 0s 1ms/step - loss: 0.6352\n",
      "Epoch 30/200\n",
      "4/4 [==============================] - 0s 1ms/step - loss: 0.6187\n",
      "Epoch 31/200\n",
      "4/4 [==============================] - 0s 1ms/step - loss: 0.6030\n",
      "Epoch 32/200\n",
      "4/4 [==============================] - 0s 2ms/step - loss: 0.5884\n",
      "Epoch 33/200\n",
      "4/4 [==============================] - 0s 1ms/step - loss: 0.5746\n",
      "Epoch 34/200\n",
      "4/4 [==============================] - 0s 1ms/step - loss: 0.5621\n",
      "Epoch 35/200\n",
      "4/4 [==============================] - 0s 1ms/step - loss: 0.5512\n",
      "Epoch 36/200\n",
      "4/4 [==============================] - 0s 1ms/step - loss: 0.5414\n",
      "Epoch 37/200\n",
      "4/4 [==============================] - 0s 1ms/step - loss: 0.5323\n",
      "Epoch 38/200\n",
      "4/4 [==============================] - 0s 1ms/step - loss: 0.5236\n",
      "Epoch 39/200\n",
      "4/4 [==============================] - 0s 1ms/step - loss: 0.5150\n",
      "Epoch 40/200\n",
      "4/4 [==============================] - 0s 1ms/step - loss: 0.5072\n",
      "Epoch 41/200\n",
      "4/4 [==============================] - 0s 1ms/step - loss: 0.5006\n",
      "Epoch 42/200\n",
      "4/4 [==============================] - 0s 1ms/step - loss: 0.4944\n",
      "Epoch 43/200\n",
      "4/4 [==============================] - 0s 1ms/step - loss: 0.4888\n",
      "Epoch 44/200\n",
      "4/4 [==============================] - 0s 1ms/step - loss: 0.4830\n",
      "Epoch 45/200\n",
      "4/4 [==============================] - 0s 1ms/step - loss: 0.4775\n",
      "Epoch 46/200\n",
      "4/4 [==============================] - 0s 1ms/step - loss: 0.4725\n",
      "Epoch 47/200\n",
      "4/4 [==============================] - 0s 1ms/step - loss: 0.4673\n",
      "Epoch 48/200\n",
      "4/4 [==============================] - 0s 1ms/step - loss: 0.4624\n",
      "Epoch 49/200\n",
      "4/4 [==============================] - 0s 1ms/step - loss: 0.4574\n",
      "Epoch 50/200\n",
      "4/4 [==============================] - 0s 1ms/step - loss: 0.4530\n",
      "Epoch 51/200\n",
      "4/4 [==============================] - 0s 1ms/step - loss: 0.4491\n",
      "Epoch 52/200\n",
      "4/4 [==============================] - 0s 1ms/step - loss: 0.4451\n",
      "Epoch 53/200\n",
      "4/4 [==============================] - 0s 1ms/step - loss: 0.4414\n",
      "Epoch 54/200\n",
      "4/4 [==============================] - 0s 1ms/step - loss: 0.4374\n",
      "Epoch 55/200\n",
      "4/4 [==============================] - 0s 1ms/step - loss: 0.4336\n",
      "Epoch 56/200\n",
      "4/4 [==============================] - 0s 1ms/step - loss: 0.4295\n",
      "Epoch 57/200\n",
      "4/4 [==============================] - 0s 999us/step - loss: 0.4261\n",
      "Epoch 58/200\n",
      "4/4 [==============================] - 0s 1ms/step - loss: 0.4225\n",
      "Epoch 59/200\n",
      "4/4 [==============================] - 0s 1ms/step - loss: 0.4193\n",
      "Epoch 60/200\n",
      "4/4 [==============================] - 0s 1ms/step - loss: 0.4161\n",
      "Epoch 61/200\n",
      "4/4 [==============================] - 0s 1ms/step - loss: 0.4131\n",
      "Epoch 62/200\n",
      "4/4 [==============================] - 0s 994us/step - loss: 0.4098\n",
      "Epoch 63/200\n",
      "4/4 [==============================] - 0s 975us/step - loss: 0.4067\n",
      "Epoch 64/200\n",
      "4/4 [==============================] - 0s 997us/step - loss: 0.4029\n",
      "Epoch 65/200\n",
      "4/4 [==============================] - 0s 991us/step - loss: 0.3994\n",
      "Epoch 66/200\n",
      "4/4 [==============================] - 0s 2ms/step - loss: 0.3957\n",
      "Epoch 67/200\n",
      "4/4 [==============================] - 0s 1ms/step - loss: 0.3920\n",
      "Epoch 68/200\n",
      "4/4 [==============================] - 0s 1ms/step - loss: 0.3878\n",
      "Epoch 69/200\n",
      "4/4 [==============================] - 0s 1ms/step - loss: 0.3837\n",
      "Epoch 70/200\n",
      "4/4 [==============================] - 0s 1ms/step - loss: 0.3792\n",
      "Epoch 71/200\n",
      "4/4 [==============================] - 0s 1ms/step - loss: 0.3755\n",
      "Epoch 72/200\n",
      "4/4 [==============================] - 0s 1ms/step - loss: 0.3718\n",
      "Epoch 73/200\n",
      "4/4 [==============================] - 0s 1ms/step - loss: 0.3683\n",
      "Epoch 74/200\n",
      "4/4 [==============================] - 0s 1ms/step - loss: 0.3643\n",
      "Epoch 75/200\n",
      "4/4 [==============================] - 0s 1ms/step - loss: 0.3600\n",
      "Epoch 76/200\n",
      "4/4 [==============================] - 0s 1ms/step - loss: 0.3550\n",
      "Epoch 77/200\n",
      "4/4 [==============================] - 0s 1ms/step - loss: 0.3491\n",
      "Epoch 78/200\n",
      "4/4 [==============================] - 0s 1ms/step - loss: 0.3425\n",
      "Epoch 79/200\n",
      "4/4 [==============================] - 0s 1ms/step - loss: 0.3367\n",
      "Epoch 80/200\n",
      "4/4 [==============================] - 0s 1ms/step - loss: 0.3293\n",
      "Epoch 81/200\n",
      "4/4 [==============================] - 0s 1ms/step - loss: 0.3228\n",
      "Epoch 82/200\n",
      "4/4 [==============================] - 0s 1ms/step - loss: 0.3156\n",
      "Epoch 83/200\n",
      "4/4 [==============================] - 0s 1ms/step - loss: 0.3080\n",
      "Epoch 84/200\n",
      "4/4 [==============================] - 0s 2ms/step - loss: 0.3006\n",
      "Epoch 85/200\n",
      "4/4 [==============================] - 0s 1ms/step - loss: 0.2933\n",
      "Epoch 86/200\n",
      "4/4 [==============================] - 0s 1ms/step - loss: 0.2864\n",
      "Epoch 87/200\n",
      "4/4 [==============================] - 0s 1ms/step - loss: 0.2792\n",
      "Epoch 88/200\n",
      "4/4 [==============================] - 0s 1ms/step - loss: 0.2720\n",
      "Epoch 89/200\n",
      "4/4 [==============================] - 0s 1ms/step - loss: 0.2645\n",
      "Epoch 90/200\n",
      "4/4 [==============================] - 0s 1ms/step - loss: 0.2570\n",
      "Epoch 91/200\n",
      "4/4 [==============================] - 0s 1ms/step - loss: 0.2498\n",
      "Epoch 92/200\n",
      "4/4 [==============================] - 0s 1ms/step - loss: 0.2432\n",
      "Epoch 93/200\n",
      "4/4 [==============================] - 0s 1ms/step - loss: 0.2354\n",
      "Epoch 94/200\n",
      "4/4 [==============================] - 0s 1ms/step - loss: 0.2274\n",
      "Epoch 95/200\n",
      "4/4 [==============================] - 0s 1ms/step - loss: 0.2194\n",
      "Epoch 96/200\n",
      "4/4 [==============================] - 0s 992us/step - loss: 0.2127\n",
      "Epoch 97/200\n",
      "4/4 [==============================] - 0s 997us/step - loss: 0.2060\n",
      "Epoch 98/200\n",
      "4/4 [==============================] - 0s 1ms/step - loss: 0.1995\n",
      "Epoch 99/200\n",
      "4/4 [==============================] - 0s 1ms/step - loss: 0.1950\n",
      "Epoch 100/200\n",
      "4/4 [==============================] - 0s 1ms/step - loss: 0.1894\n",
      "Epoch 101/200\n",
      "4/4 [==============================] - 0s 2ms/step - loss: 0.1850\n",
      "Epoch 102/200\n",
      "4/4 [==============================] - 0s 1ms/step - loss: 0.1804\n",
      "Epoch 103/200\n",
      "4/4 [==============================] - 0s 1ms/step - loss: 0.1758\n",
      "Epoch 104/200\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "4/4 [==============================] - 0s 1ms/step - loss: 0.1709\n",
      "Epoch 105/200\n",
      "4/4 [==============================] - 0s 1ms/step - loss: 0.1662\n",
      "Epoch 106/200\n",
      "4/4 [==============================] - 0s 1ms/step - loss: 0.1616\n",
      "Epoch 107/200\n",
      "4/4 [==============================] - 0s 1ms/step - loss: 0.1575\n",
      "Epoch 108/200\n",
      "4/4 [==============================] - 0s 1ms/step - loss: 0.1527\n",
      "Epoch 109/200\n",
      "4/4 [==============================] - 0s 1ms/step - loss: 0.1480\n",
      "Epoch 110/200\n",
      "4/4 [==============================] - 0s 1ms/step - loss: 0.1439\n",
      "Epoch 111/200\n",
      "4/4 [==============================] - 0s 1ms/step - loss: 0.1396\n",
      "Epoch 112/200\n",
      "4/4 [==============================] - 0s 1ms/step - loss: 0.1357\n",
      "Epoch 113/200\n",
      "4/4 [==============================] - 0s 1ms/step - loss: 0.1315\n",
      "Epoch 114/200\n",
      "4/4 [==============================] - 0s 1ms/step - loss: 0.1277\n",
      "Epoch 115/200\n",
      "4/4 [==============================] - 0s 1ms/step - loss: 0.1240\n",
      "Epoch 116/200\n",
      "4/4 [==============================] - 0s 1ms/step - loss: 0.1207\n",
      "Epoch 117/200\n",
      "4/4 [==============================] - 0s 984us/step - loss: 0.1171\n",
      "Epoch 118/200\n",
      "4/4 [==============================] - 0s 997us/step - loss: 0.1139\n",
      "Epoch 119/200\n",
      "4/4 [==============================] - 0s 2ms/step - loss: 0.1110\n",
      "Epoch 120/200\n",
      "4/4 [==============================] - 0s 1ms/step - loss: 0.1084\n",
      "Epoch 121/200\n",
      "4/4 [==============================] - 0s 1ms/step - loss: 0.1058\n",
      "Epoch 122/200\n",
      "4/4 [==============================] - 0s 1ms/step - loss: 0.1029\n",
      "Epoch 123/200\n",
      "4/4 [==============================] - 0s 1ms/step - loss: 0.1001\n",
      "Epoch 124/200\n",
      "4/4 [==============================] - 0s 1ms/step - loss: 0.0975\n",
      "Epoch 125/200\n",
      "4/4 [==============================] - 0s 1ms/step - loss: 0.0951\n",
      "Epoch 126/200\n",
      "4/4 [==============================] - 0s 1ms/step - loss: 0.0925\n",
      "Epoch 127/200\n",
      "4/4 [==============================] - 0s 1ms/step - loss: 0.0902\n",
      "Epoch 128/200\n",
      "4/4 [==============================] - 0s 1ms/step - loss: 0.0882\n",
      "Epoch 129/200\n",
      "4/4 [==============================] - 0s 1ms/step - loss: 0.0862\n",
      "Epoch 130/200\n",
      "4/4 [==============================] - 0s 1ms/step - loss: 0.0840\n",
      "Epoch 131/200\n",
      "4/4 [==============================] - 0s 1ms/step - loss: 0.0826\n",
      "Epoch 132/200\n",
      "4/4 [==============================] - 0s 1ms/step - loss: 0.0807\n",
      "Epoch 133/200\n",
      "4/4 [==============================] - 0s 1ms/step - loss: 0.0790\n",
      "Epoch 134/200\n",
      "4/4 [==============================] - 0s 983us/step - loss: 0.0772\n",
      "Epoch 135/200\n",
      "4/4 [==============================] - 0s 1000us/step - loss: 0.0757\n",
      "Epoch 136/200\n",
      "4/4 [==============================] - 0s 995us/step - loss: 0.0741\n",
      "Epoch 137/200\n",
      "4/4 [==============================] - 0s 1ms/step - loss: 0.0728\n",
      "Epoch 138/200\n",
      "4/4 [==============================] - 0s 1ms/step - loss: 0.0714\n",
      "Epoch 139/200\n",
      "4/4 [==============================] - 0s 1ms/step - loss: 0.0700\n",
      "Epoch 140/200\n",
      "4/4 [==============================] - 0s 1ms/step - loss: 0.0685\n",
      "Epoch 141/200\n",
      "4/4 [==============================] - 0s 1ms/step - loss: 0.0670\n",
      "Epoch 142/200\n",
      "4/4 [==============================] - 0s 1ms/step - loss: 0.0657\n",
      "Epoch 143/200\n",
      "4/4 [==============================] - 0s 1ms/step - loss: 0.0645\n",
      "Epoch 144/200\n",
      "4/4 [==============================] - 0s 1ms/step - loss: 0.0634\n",
      "Epoch 145/200\n",
      "4/4 [==============================] - 0s 1ms/step - loss: 0.0622\n",
      "Epoch 146/200\n",
      "4/4 [==============================] - 0s 1ms/step - loss: 0.0611\n",
      "Epoch 147/200\n",
      "4/4 [==============================] - 0s 1ms/step - loss: 0.0601\n",
      "Epoch 148/200\n",
      "4/4 [==============================] - 0s 1ms/step - loss: 0.0591\n",
      "Epoch 149/200\n",
      "4/4 [==============================] - 0s 1ms/step - loss: 0.0582\n",
      "Epoch 150/200\n",
      "4/4 [==============================] - 0s 1ms/step - loss: 0.0575\n",
      "Epoch 151/200\n",
      "4/4 [==============================] - 0s 1ms/step - loss: 0.0570\n",
      "Epoch 152/200\n",
      "4/4 [==============================] - 0s 1ms/step - loss: 0.0563\n",
      "Epoch 153/200\n",
      "4/4 [==============================] - 0s 1ms/step - loss: 0.0553\n",
      "Epoch 154/200\n",
      "4/4 [==============================] - 0s 1ms/step - loss: 0.0541\n",
      "Epoch 155/200\n",
      "4/4 [==============================] - 0s 1ms/step - loss: 0.0530\n",
      "Epoch 156/200\n",
      "4/4 [==============================] - 0s 1ms/step - loss: 0.0519\n",
      "Epoch 157/200\n",
      "4/4 [==============================] - 0s 1ms/step - loss: 0.0510\n",
      "Epoch 158/200\n",
      "4/4 [==============================] - 0s 1ms/step - loss: 0.0502\n",
      "Epoch 159/200\n",
      "4/4 [==============================] - 0s 1ms/step - loss: 0.0496\n",
      "Epoch 160/200\n",
      "4/4 [==============================] - 0s 1ms/step - loss: 0.0490\n",
      "Epoch 161/200\n",
      "4/4 [==============================] - 0s 1ms/step - loss: 0.0481\n",
      "Epoch 162/200\n",
      "4/4 [==============================] - 0s 1ms/step - loss: 0.0473\n",
      "Epoch 163/200\n",
      "4/4 [==============================] - 0s 1ms/step - loss: 0.0466\n",
      "Epoch 164/200\n",
      "4/4 [==============================] - 0s 1ms/step - loss: 0.0458\n",
      "Epoch 165/200\n",
      "4/4 [==============================] - 0s 1ms/step - loss: 0.0452\n",
      "Epoch 166/200\n",
      "4/4 [==============================] - 0s 1ms/step - loss: 0.0445\n",
      "Epoch 167/200\n",
      "4/4 [==============================] - 0s 1ms/step - loss: 0.0440\n",
      "Epoch 168/200\n",
      "4/4 [==============================] - 0s 1ms/step - loss: 0.0434\n",
      "Epoch 169/200\n",
      "4/4 [==============================] - 0s 1ms/step - loss: 0.0429\n",
      "Epoch 170/200\n",
      "4/4 [==============================] - 0s 1ms/step - loss: 0.0423\n",
      "Epoch 171/200\n",
      "4/4 [==============================] - 0s 2ms/step - loss: 0.0418\n",
      "Epoch 172/200\n",
      "4/4 [==============================] - 0s 1ms/step - loss: 0.0414\n",
      "Epoch 173/200\n",
      "4/4 [==============================] - 0s 1ms/step - loss: 0.0413\n",
      "Epoch 174/200\n",
      "4/4 [==============================] - 0s 1ms/step - loss: 0.0408\n",
      "Epoch 175/200\n",
      "4/4 [==============================] - 0s 1ms/step - loss: 0.0401\n",
      "Epoch 176/200\n",
      "4/4 [==============================] - 0s 1ms/step - loss: 0.0393\n",
      "Epoch 177/200\n",
      "4/4 [==============================] - 0s 1ms/step - loss: 0.0388\n",
      "Epoch 178/200\n",
      "4/4 [==============================] - 0s 1ms/step - loss: 0.0384\n",
      "Epoch 179/200\n",
      "4/4 [==============================] - 0s 1ms/step - loss: 0.0384\n",
      "Epoch 180/200\n",
      "4/4 [==============================] - 0s 1ms/step - loss: 0.0377\n",
      "Epoch 181/200\n",
      "4/4 [==============================] - 0s 998us/step - loss: 0.0369\n",
      "Epoch 182/200\n",
      "4/4 [==============================] - 0s 1ms/step - loss: 0.0366\n",
      "Epoch 183/200\n",
      "4/4 [==============================] - 0s 1ms/step - loss: 0.0363\n",
      "Epoch 184/200\n",
      "4/4 [==============================] - 0s 1ms/step - loss: 0.0359\n",
      "Epoch 185/200\n",
      "4/4 [==============================] - 0s 1ms/step - loss: 0.0353\n",
      "Epoch 186/200\n",
      "4/4 [==============================] - 0s 1ms/step - loss: 0.0348\n",
      "Epoch 187/200\n",
      "4/4 [==============================] - 0s 1ms/step - loss: 0.0345\n",
      "Epoch 188/200\n",
      "4/4 [==============================] - 0s 1ms/step - loss: 0.0343\n",
      "Epoch 189/200\n",
      "4/4 [==============================] - 0s 1ms/step - loss: 0.0339\n",
      "Epoch 190/200\n",
      "4/4 [==============================] - 0s 1ms/step - loss: 0.0337\n",
      "Epoch 191/200\n",
      "4/4 [==============================] - 0s 1ms/step - loss: 0.0333\n",
      "Epoch 192/200\n",
      "4/4 [==============================] - 0s 1ms/step - loss: 0.0330\n",
      "Epoch 193/200\n",
      "4/4 [==============================] - 0s 1ms/step - loss: 0.0325\n",
      "Epoch 194/200\n",
      "4/4 [==============================] - 0s 1ms/step - loss: 0.0321\n",
      "Epoch 195/200\n",
      "4/4 [==============================] - 0s 1ms/step - loss: 0.0317\n",
      "Epoch 196/200\n",
      "4/4 [==============================] - 0s 1ms/step - loss: 0.0314\n",
      "Epoch 197/200\n",
      "4/4 [==============================] - 0s 1ms/step - loss: 0.0310\n",
      "Epoch 198/200\n",
      "4/4 [==============================] - 0s 1ms/step - loss: 0.0306\n",
      "Epoch 199/200\n",
      "4/4 [==============================] - 0s 1ms/step - loss: 0.0303\n",
      "Epoch 200/200\n",
      "4/4 [==============================] - 0s 1ms/step - loss: 0.0300\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "<keras.callbacks.History at 0x7f3ca47e9d90>"
      ]
     },
     "execution_count": 8,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "model.compile(\n",
    "    loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),\n",
    "    optimizer=tf.keras.optimizers.Adam(0.01),\n",
    ")\n",
    "\n",
    "model.fit(\n",
    "    X_train,y_train,\n",
    "    epochs=200\n",
    ")"
   ]
  },
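  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "For a single example with label $y$ and linear outputs (logits) $z_0,\\dots,z_3$, the loss configured above is the softmax cross-entropy. The standard formula is written out here as a sketch of why working from the logits helps; it is not a description of TensorFlow's internal code.\n",
    "\n",
    "$$L(\\mathbf{z}, y) = -\\log\\frac{e^{z_y}}{\\sum_{k=0}^{3} e^{z_k}} = -z_y + \\log\\sum_{k=0}^{3} e^{z_k}$$\n",
    "\n",
    "When the softmax is folded into the loss, the logarithm and the exponentials can be combined. Subtracting $m = \\max_k z_k$ before exponentiating leaves the value unchanged and keeps every exponent at or below zero:\n",
    "\n",
    "$$\\log\\sum_{k=0}^{3} e^{z_k} = m + \\log\\sum_{k=0}^{3} e^{z_k - m}$$\n",
    "\n",
    "A loss that only receives already-computed probabilities cannot apply this rearrangement, which is one way to see why the lab passes linear outputs and sets `from_logits=True`."
   ]
  },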
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "With the model trained, we can see how the model has classified the training data."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "e22967edb8dd44f696b06a09ce4a6eee",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "plt_cat_mc(X_train, y_train, model, classes)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Above, the decision boundaries show how the model has partitioned the input space.  This very simple model has had no trouble classifying the training data. How did it accomplish this? Let's look at the network in more detail. \n",
    "\n",
    "Below, we will pull the trained weights from the model and use them to plot the function of each of the network units. Further down, there is a more detailed explanation of the results. You don't need to know these details to successfully use neural networks, but it may be helpful to gain more intuition about how the layers combine to solve a classification problem."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [],
   "source": [
    "# gather the trained parameters from the first layer\n",
    "l1 = model.get_layer(\"L1\")\n",
    "W1,b1 = l1.get_weights()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "06a6008699124ea59de975e69a8bc354",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "# plot the function of the first layer\n",
    "plt_layer_relu(X_train, y_train.reshape(-1,), W1, b1, classes)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {
    "tags": []
   },
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "1698c9780d8c41c7aca73763c21924c1",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "# gather the trained parameters from the output layer\n",
    "l2 = model.get_layer(\"L2\")\n",
    "W2, b2 = l2.get_weights()\n",
    "# create the 'new features', the training examples after L1 transformation\n",
    "Xl2 = np.maximum(0, np.dot(X_train,W1) + b1)\n",
    "\n",
    "plt_output_layer_linear(Xl2, y_train.reshape(-1,), W2, b2, classes,\n",
    "                        x0_rng = (-0.25,np.amax(Xl2[:,0])), x1_rng = (-0.25,np.amax(Xl2[:,1])))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Explanation\n",
    "#### Layer 1 <img align=\"Right\" src=\"./images/C2_W2_mclass_layer1.png\"  style=\" width:600px; padding: 10px 20px ; \">\n",
    "These plots show the function of Units 0 and 1 in the first layer of the network. The inputs ($x_0,x_1$) are on the axes. The output of the unit is represented by the color of the background, as indicated by the color bar on the right of each graph. Notice that since these units use a ReLU, the outputs do not necessarily fall between 0 and 1, and in this case are greater than 20 at their peaks. \n",
    "The contour lines in this graph show the transition point between the output, $a^{[1]}_j$, being zero and non-zero. Recall the graph for a ReLU:<img align=\"right\" src=\"./images/C2_W2_mclass_relu.png\"  style=\" width:200px; padding: 10px 20px ; \"> The contour line in the graph is the corner point of the ReLU, where the output switches from zero to positive.\n",
    "\n",
    "Unit 0 has separated classes 0 and 1 from classes 2 and 3. Points to the left of the line (classes 0 and 1) will output zero, while points to the right will output a value greater than zero.  \n",
    "Unit 1 has separated classes 0 and 2 from classes 1 and 3. Points above the line (classes 0 and 2) will output a zero, while points below will output a value greater than zero. Let's see how this works out in the next layer!"
   ]
  },
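  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The behaviour of the first-layer units can be reproduced with a few lines of NumPy. The parameters below are hypothetical values chosen for illustration (they are not the trained `W1` and `b1`): unit 0 responds to points right of a vertical line and unit 1 to points below a horizontal line, mirroring the description above. The inputs are the four cluster centers used to generate the data.\n",
    "\n",
    "```python\n",
    "import numpy as np\n",
    "\n",
    "# hypothetical layer-1 parameters, one column per unit (not the trained values)\n",
    "W1_demo = np.array([[1.0,  0.0],\n",
    "                    [0.0, -1.0]])\n",
    "b1_demo = np.array([0.5, 0.0])\n",
    "\n",
    "# the four cluster centers from section 2.1, classes 0..3\n",
    "centers = np.array([[-5, 2], [-2, -2], [1, 2], [5, -2]])\n",
    "\n",
    "# ReLU layer: a = max(0, x.W + b)\n",
    "A1 = np.maximum(0, centers @ W1_demo + b1_demo)\n",
    "print(A1)\n",
    "```\n",
    "\n",
    "With these made-up weights, classes 0 and 1 get a zero from unit 0 and classes 0 and 2 get a zero from unit 1, so each class center lands in a different corner of the new feature space."
   ]
  },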
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Layer 2, the output layer  <img align=\"Right\" src=\"./images/C2_W2_mclass_layer2.png\"  style=\" width:600px; padding: 10px 20px ; \">\n",
    "\n",
    "The dots in these graphs are the training examples transformed by the first layer. One way to think of this is that the first layer has created a new set of features for evaluation by the second layer. The axes in these plots are the outputs of the previous layer, $a^{[1]}_0$ and $a^{[1]}_1$. As predicted above, classes 0 and 1 (blue and green) have $a^{[1]}_0 = 0$ while classes 0 and 2 (blue and orange) have $a^{[1]}_1 = 0$.  \n",
    "Once again, the intensity of the background color indicates the highest values.  \n",
    "Unit 0 will produce its maximum value for values near (0,0), where class 0 (blue) has been mapped.    \n",
    "Unit 1 produces its highest values in the upper left corner, selecting class 1 (green).  \n",
    "Unit 2 targets the lower right corner where class 2 (orange) resides.  \n",
    "Unit 3 produces its highest values in the upper right selecting our final class (purple).  \n",
    "\n",
    "One other aspect that is not obvious from the graphs is that the values have been coordinated between the units. It is not sufficient for a unit to produce a maximum value for the class it is selecting for; it must also be the highest value of all the units for points in that class. This is done by the implied softmax function that is part of the loss function (`SparseCategoricalCrossentropy`). Unlike other activation functions, the softmax works across all the outputs.\n",
    "\n",
    "You can successfully use neural networks without knowing the details of what each unit is up to. Hopefully, this example has provided some intuition about what is happening under the hood."
   ]
  },
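  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The coordination between output units can be stated compactly. As a sketch in the notation used above, let $\\mathbf{w}^{[2]}_j$ and $b^{[2]}_j$ denote the weights and bias of output unit $j$ (symbols introduced here for illustration):\n",
    "\n",
    "$$z^{[2]}_j = \\mathbf{w}^{[2]}_j \\cdot \\mathbf{a}^{[1]} + b^{[2]}_j, \\qquad \\hat{y} = \\arg\\max_j z^{[2]}_j$$\n",
    "\n",
    "The softmax probability for unit $j$ is $e^{z^{[2]}_j} / \\sum_k e^{z^{[2]}_k}$. It increases with $z^{[2]}_j$ and shares one denominator across all units, so the unit with the largest linear output also has the largest probability. Training therefore has to push $z^{[2]}_y$ above every other $z^{[2]}_k$ for examples of class $y$, not merely make it large."
   ]
  },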
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Congratulations!\n",
    "You have learned to build and operate a neural network for multiclass classification.\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.6"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
